import torch
import torch.nn as nn
import numpy as np
import cv2 as cv
from model.shufflenet import ShuffleNet
from model.fpn import ZhnNetFpn
from model.head import ZhnNetHead


class ZhnNet(nn.Module):
    """Network that chains a ShuffleNet backbone, a ZhnNetFpn and a ZhnNetHead.

    The input is passed through ``backbone``, then ``fpn``, then ``head``;
    the output of ``head`` is returned as-is.

    :param train_backbone: forwarded unchanged to ``ZhnNetHead``. Note that it
        is not passed to the backbone itself in this class; its meaning is
        defined by ``ZhnNetHead``.
    :param device: forwarded unchanged to ``ZhnNetHead``. This class does not
        move any module to the device on its own.
    """

    def __init__(self, train_backbone=False, device=torch.device('cpu')):
        super().__init__()
        # Sub-modules are registered in the order in which they are applied.
        self.backbone = ShuffleNet()
        self.fpn = ZhnNetFpn()
        self.head = ZhnNetHead(train_backbone, device)

    def forward(self, x):
        """Run ``x`` through backbone, fpn and head, and return the head output."""
        for stage in (self.backbone, self.fpn, self.head):
            x = stage(x)
        return x


def image_generate_conf(img, label):
    """Write a confidence value onto an image.

    :param img: original image that was given to the network
    :param label: decoded network output; it is rendered with four decimal
        places, so it must support the ``.4f`` format specification
    :return: the value returned by ``cv.putText``
    """
    caption = '{:.4f}'.format(label)
    anchor = (160, 160)  # fixed position at which the text is placed
    color = (0, 0, 255)
    font_scale = 1.2
    thickness = 1
    return cv.putText(img, caption, anchor, cv.FONT_HERSHEY_PLAIN,
                      font_scale, color, thickness)


def image_generate_loc(img, label, scale=128, conf_threshold=0.5):
    """Draw the most confident predicted box and its confidence on an image.

    ``label`` is interpreted as rows of five values
    ``(center_x, center_y, width, height, confidence)``. The row with the
    highest confidence is selected; if that confidence is below
    ``conf_threshold`` nothing is drawn.

    :param img: original image that was given to the network
    :param label: decoded network output, a ``torch.Tensor`` or array-like
        whose number of elements is a multiple of 5. Tensors may require grad
        or live on a non-CPU device; they are detached and copied to the CPU.
    :param scale: factor that converts the box values to pixel coordinates
        (default 128, the constant previously hard-coded here; presumably the
        side length of the network input, confirm against the caller)
    :param conf_threshold: minimum confidence required to draw (default 0.5)
    :return: ``img`` unchanged when there is no prediction or the best
        confidence is below the threshold, otherwise the result of the
        OpenCV drawing calls
    :raises ValueError: if the number of elements is not a multiple of 5
    """
    if isinstance(label, torch.Tensor):
        # .numpy() fails on tensors that require grad or are not on the CPU.
        label = label.detach().cpu().numpy()
    preds = np.asarray(label)
    if preds.size == 0:
        # No predictions at all: argmax on an empty array would raise.
        return img
    # NumPy reshape also accepts non-contiguous data, unlike Tensor.view.
    preds = preds.reshape(-1, 5)

    best = preds[preds[:, 4].argmax()]
    if best[4] < conf_threshold:
        return img

    # Convert (center, size) to (top-left, bottom-right) corners.
    half_size = best[2:4] / 2
    corner = np.concatenate((best[:2] - half_size, best[:2] + half_size))
    # Use plain Python ints for the OpenCV point tuples; NumPy integer
    # scalars are not accepted by every OpenCV version.
    x1, y1, x2, y2 = np.round(corner * scale).astype('int32').tolist()

    img = cv.rectangle(img, (x1, y1), (x2, y2), color=(0, 0, 255), thickness=2)
    pred_conf = '{:.4f}'.format(best[4])
    img = cv.putText(img, pred_conf, (x1, y1), cv.FONT_HERSHEY_PLAIN, 1.2, (0, 0, 255), 1)
    return img
